"""

A collection of basic statistical test functions for Python.

"""

from turtle import shape
import numpy as np
import scipy.stats
from StatLearnPack.table import simple_table
from StatLearnPack.general import cor, acf
import abc


# Names exported by a star-import of this module. Other definitions in this
# file (e.g. Levene, levenetest, ztest, boxtest) are not listed here.
__all__ = ['vartest', 'proptest', 'ttest', 'Cor', 'chisqtest']


class test(object):
    """
    Base class for test results.

    All test functions should return an instance of this class.
    It currently defines no attributes or methods of its own.

    """

class test_(metaclass=abc.ABCMeta):
    """
    Parent class of all hypothesis-test classes.

    Parameters :
    ----------
    p_value : p-value of the test.
    stat : value of the test statistic.
    summary : summary of the test result.
    """
    def __init__(self, p_value, stat, summary):
        self.p_value = p_value
        # Keep every constructor argument so subclasses and callers can read
        # them back; previously `stat` and `summary` were silently discarded.
        self.stat = stat
        self.summary = summary

class Levene:
    """
    Levene class:
    this class provides the two-sample "test for homogeneity of variance"
    (absolute deviations are taken from each group's mean).

    Member methods:
                    fit
    You should initialize the significance level when instantiating the class
    >>> L = Levene(alpha=0.01)
    >>> H0, F, p = L.fit(x1, x2, dis=False)
    """
    def __init__(self, alpha):
        self.alpha = alpha

    def fit(self, x1: np.array, x2: np.array, dis=True):
        """
        Run Levene's test on two samples.

        Parameters :
        ----------
        x1, x2 : array_like samples; they may have different lengths.
        dis : if true, print a table with the statistic, p-value and decision.

        Returns :
        ----------
        [H0, F, p_value] where H0 is True when the hypothesis of equal
        variances is NOT rejected at level `alpha`.

        Raises :
        ----------
        ValueError if either sample is empty.
        """
        x1 = np.asarray(x1, dtype=float).ravel()
        x2 = np.asarray(x2, dtype=float).ravel()
        n1, n2 = x1.size, x2.size
        if n1 == 0 or n2 == 0:
            raise ValueError('x1 and x2 must each contain at least one observation')
        n = n1 + n2

        # Absolute deviations from each group's own mean.
        z1 = np.abs(x1 - np.mean(x1))
        z2 = np.abs(x2 - np.mean(x2))
        mu_z1, mu_z2 = np.mean(z1), np.mean(z2)
        # concatenate (not vstack) so that samples of unequal size are accepted.
        mu_z = np.mean(np.concatenate([z1, z2]))

        between = n1 * (mu_z1 - mu_z) ** 2 + n2 * (mu_z2 - mu_z) ** 2
        within = np.sum((z1 - mu_z1) ** 2) + np.sum((z2 - mu_z2) ** 2)
        F = between / (within / (n - 2))
        # Survival function keeps precision for very small p-values.
        p_leven_f = scipy.stats.f.sf(F, 1, n - 2)
        H0 = bool(p_leven_f > self.alpha)

        if dis:
            # Use the table helper shared by the rest of this module.
            tb = simple_table(title="Levene's Test for Homogeneity of Variance (center = mean)")
            tb.add_line(style='double')
            tb.add_row(['levene-F-stat = {0}'.format(np.round(F, 4)),
                        'levene p-value = {0}'.format(np.round(p_leven_f, 4)),
                        'H0 = {0}'.format(H0)])
            tb.add_line(style='double')
            print(tb.__str__())
        return [H0, F, p_leven_f]

class Cor(test):
    """
    This class implements the correlation coefficient and its significance test.

    Supported coefficients : 'pearson' and 'spearman'.

    method : test(x: np.array, y: np.array, method='pearson')
    ----------
    """

    def __init__(self):
        self.__method = 'pearson'
        self.__methods = ('pearson', 'spearman')

    def calculate_cor(self, x: np.array, y: np.array, method="pearson"):
        """
        Calculate the correlation coefficient between x and y.

        Parameters :
        ----------
        x, y : array_like with the same number of elements (n or n * 1).
        method : 'pearson' (default) or 'spearman'.

        Returns :
        ----------
        The correlation coefficient; it is also stored in `self.r`.

        Raises :
        ----------
        ValueError if `method` is unknown or x and y differ in length.
        """
        # Reject unknown methods up front; otherwise `self.r` would be stale
        # or missing when returned.
        if method not in self.__methods:
            raise ValueError('unknown method "{0}"'.format(method))

        x = np.asarray(x, dtype=float).ravel()
        y = np.asarray(y, dtype=float).ravel()
        if x.shape[0] != y.shape[0]:
            raise ValueError('x and y must have the same length')

        self.__method = method
        self.__x = x
        self.__y = y
        self.num_example = x.shape[0]
        if method == "pearson":
            self.r = self.__pearson(x, y)
        else:
            self.r = self.__spearman(x, y)
        return self.r

    def __spearman(self, x: np.array, y: np.array):
        """Spearman rank correlation coefficient (ties get average ranks)."""
        # Ranks, not argsort indices: argsort returns the positions that would
        # sort the data, which is the inverse permutation of the ranks.
        rank_x = scipy.stats.rankdata(x)
        rank_y = scipy.stats.rankdata(y)
        # Pearson correlation of the ranks equals 1 - 6*sum(d^2)/(n(n^2-1))
        # without ties and stays valid when ties are present.
        self.r = self.__pearson(rank_x, rank_y)
        return self.r

    def __pearson(self, x: np.array, y: np.array):
        """Pearson correlation coefficient"""
        dx = x - np.mean(x)
        dy = y - np.mean(y)
        self.r = np.sum(dx * dy) / (np.sqrt(np.sum(dx ** 2)) * np.sqrt(np.sum(dy ** 2)))
        return self.r

    def test(self, x, y, method='pearson'):
        """
        Test whether the Pearson / Spearman correlation coefficient is zero.

        The statistic t = r * sqrt((n - 2) / (1 - r^2)) is compared with a
        Student t distribution with n - 2 degrees of freedom (two-sided).

        Returns :
        ----------
        A formatted table (str). `self.r`, `self.t` and `self.p_value` are set.

        Raises :
        ----------
        ValueError if `method` is unknown or fewer than 3 observations are given.
        """
        if method not in self.__methods:
            raise ValueError('unknown method "{0}"'.format(method))

        self.calculate_cor(x, y, method=method)
        df = self.num_example - 2
        if df < 1:
            raise ValueError('at least 3 observations are required')

        self.t = np.sqrt(df) * self.r / np.sqrt(1 - self.r**2)
        # Use the same n - 2 degrees of freedom that the table reports.
        self.p_value = 2 * scipy.stats.t.sf(np.abs(self.t), df=df)

        table = simple_table(title='correlation test', unit_width=30)
        table.add_line(style='double')
        table.add_row(['rho = ' + str(np.round(self.r, 4)), 'Number of obs = ' + str(self.num_example), 'Type : ' + str(self.__method)])
        table.add_row(['t statistic = ' + str(np.round(self.t, 4)), 'P-value = ' + str(np.round(self.p_value, 4)), "df = " + str(df)])
        table.add_row(["alternate hypothesis is : ", ' rho is not 0', " "])
        table.add_line(style='double')
        return table.__str__()


def vartest(x, y, ratio=1, alternative='tow_sided', conf_level=0.95):
    """
    This function performs an F test to compare the variances of two samples from normal populations.

    Parameters :
    ----------
    x, y : array_like
    ratio : your hypothesized ratio of the var(x) / var(y). ratio should be greater than 0.
    alternative : 'tow_sided' (default; 'two_sided' is accepted as well), 'less' or 'greater'.
    conf_level : confidence level of the interval for the true variance ratio.

    Returns :
    ----------
    A formatted table (str) with the F statistic, degrees of freedom, p-value,
    confidence interval and the sample ratio of variances.

    Raises :
    ----------
    ValueError for an unknown alternative, a non-positive ratio, a conf_level
    outside [0, 1] or a sample with fewer than 2 observations.
    """
    # 'tow_sided' is the historical spelling of this API; accept the correct
    # spelling too so it matches ttest.
    if alternative == 'two_sided':
        alternative = 'tow_sided'
    if alternative not in ('tow_sided', 'less', 'greater'):
        raise ValueError('alternative should be tow_sided, less or greater')
    if ratio <= 0:
        raise ValueError('ratio should greater tahn 0.')
    if conf_level > 1 or conf_level < 0:
        raise ValueError("conf_level should greater than 0 and less than 1")

    x, y = np.asarray(x), np.asarray(y)
    n1, n2 = x.shape[0], y.shape[0]
    if n1 < 2 or n2 < 2:
        raise ValueError('x and y must each contain at least 2 observations')
    df1, df2 = n1 - 1, n2 - 1
    sig = 1 - conf_level

    # Point estimate of var(x)/var(y); the statistic rescales it by the
    # hypothesized ratio, the confidence interval does not.
    estimate = np.var(x, ddof=1) / np.var(y, ddof=1)
    f_stat = estimate / ratio

    lower, upper = 0, np.inf
    if alternative == 'tow_sided':
        alter_str = 'alternative hypothesis: true ratio of variances is not equal to {0}'.format(ratio)
        prob = 2 * min(scipy.stats.f.cdf(f_stat, df1, df2), scipy.stats.f.sf(f_stat, df1, df2))
        lower = estimate / scipy.stats.f.ppf(1 - 0.5 * sig, df1, df2)
        upper = estimate / scipy.stats.f.ppf(0.5 * sig, df1, df2)
    elif alternative == 'less':
        alter_str = 'alternative hypothesis: true ratio of variances is less than {0}'.format(ratio)
        prob = scipy.stats.f.cdf(f_stat, df1, df2)
        upper = estimate / scipy.stats.f.ppf(sig, df1, df2)
    else:
        alter_str = 'alternative hypothesis: true ratio of variances is greater than {0}'.format(ratio)
        prob = scipy.stats.f.sf(f_stat, df1, df2)
        lower = estimate / scipy.stats.f.ppf(conf_level, df1, df2)
    confint = [lower, upper]

    table = simple_table(title="F test to compare two variances", unit_width=12)
    table.add_line(style='double')
    table.add_row(['F = {0}'.format(np.round(f_stat, 4)),
                   'num df = {0}'.format(df1), 'denom = {0}'.format(df2),
                   'p-value = {0}'.format(np.round(prob, 4))])
    table.add_row([alter_str])
    table.add_row(['{0} percent confidence interval:'.format(np.round(conf_level * 100, 2)),
                   '({0}, {1})'.format(np.round(confint[0], 4), np.round(confint[1], 4))])
    table.add_row(['sample estimates:', 'ratio of variances = {0}'.format(np.round(estimate, 4))])
    table.add_line(style='double')

    return table.__str__()

def levenetest(x, y, center='median'):
    """
    Levene's test for homogeneity of variance.

    NOTE: only the argument check and the table header are implemented so
    far; `x` and `y` are accepted but not used, and no statistic is computed.

    Parameters :
    ----------
    x, y : samples (currently unused).
    center : 'mean', 'median' or 'trimmed'; only echoed in the table title.

    Returns :
    ----------
    The string form of a table that carries the title only.

    Raises :
    ----------
    ValueError if `center` is not one of the three supported values.

    Reference :
    [1] Brown, M. B. , and A. B. Forsythe . "Robust Tests for the Equality of Variances."
        Publications of the American Statistical Association 69.346(1974):364-367.

    """
    allowed_centers = ('mean', 'median', 'trimmed')
    if center not in allowed_centers:
        raise ValueError(" '{0}' not in 'mean', 'median' or 'trimmed'.".format(center))

    header = "Levene's Test for Homogeneity of Variance (center = {0})".format(center)
    return simple_table(header).__str__()





def ttest(x, y=None, mu=0, var_equal=True, conf_level=0.95, paired=False, alternative='two_sided'):

    """
    Performs one or two sample t-tests for your data.

    In this function we provide 3 type of t-test.
    + One sample t-test
    + two independent sample t-test
    + two related samples t-test

    Parameters :
    ----------
    x, y : array_like. If `y` is omitted a one-sample test on `x` is done.

    mu : hypothesized mean (one sample) or hypothesized difference in means.

    var_equal : bool
    If you set `True`, that assumes population variances are equal (pooled
    variance); otherwise Welch's approximation is used.
    `True` is the default option.

    conf_level : confidence level of the reported interval, in [0, 1].

    paired : bool
    Set this option to True when you need to use two correlated samples t-test

    alternative : str alternative hypothesis.
    Our available options are as follows:
    + 'two_sided': true mean (or difference in means) != mu
    + 'less': true mean (or difference in means) < mu
    + 'greater': true mean (or difference in means) > mu
    `two_sided` is the default option ('tow_sided' is accepted as an alias).

    Returns :
    ----------
    A formatted table (str) with the t statistic, degrees of freedom,
    p-value, confidence interval and sample estimates.

    Raises :
    ----------
    ValueError for an invalid conf_level or alternative, for samples with
    fewer than 2 observations, or for paired samples of different length.

    Examples :
    ----------
    One-Sample ttest :

    >>> # One-Sample t-test
    >>> from StatLearnPack import stat_test
    >>> import numpy as np
    >>> x = np.array([1,2,3])
    >>> stat_test.ttest(x)

    Two independent sample t-test :
    If you input x and y, we do the Two independent sample t-test and use default option `var_equal`

    >>> # two independent sample t-test
    >>> from StatLearnPack import stat_test
    >>> import numpy as np
    >>> x = np.array([1,2,3])
    >>> y = np.array([1,2,3])
    >>> print(stat_test.ttest(x,y,var_equal=True)) # var_equal = TRUE is default option

    If you set `paired=True`, we do the two related samples t-test.

    >>> # two related samples t-test
    >>> from StatLearnPack import stat_test
    >>> import numpy as np
    >>> x = np.array([1,2,3,3,2,4,2])
    >>> y = np.array([1,2,3,4,5,7,8])
    >>> print(stat_test.ttest(x,y, paired=True))

    """

    if conf_level > 1 or conf_level < 0:
        raise ValueError("conf_level should greater than 0 and less than 1")
    # Accept the spelling used by vartest/proptest as an alias.
    if alternative == 'tow_sided':
        alternative = 'two_sided'
    if alternative not in ('two_sided', 'less', 'greater'):
        raise ValueError('alternative should be two_sided, less or greater')

    x = np.asarray(x)
    if y is not None:
        y = np.asarray(y)

    def __check_test_type(y, paired):
        """
        Return the kind of t-test to run.

        + 0 : One-Sample t-test
        + 1 : two independent sample t-test
        + 2 : two related samples t-test
        """
        # Decide from the presence of y, never from the data values: a zero
        # in x or y must not change which test is performed.
        if y is None:
            return 0
        return 2 if paired else 1

    def __ttest_common(t, estimate, denom, df, conf_level, alternative):
        """
        Common code for the 3 kinds of t-test.

        `estimate` is the sample mean (or difference in means) and `denom` its
        standard error. Returns the confidence interval for the true mean
        (or difference) and the p-value matching `alternative`.
        """
        confint = [-np.inf, np.inf]
        if alternative == 'two_sided':
            prob = 2 * scipy.stats.t.sf(np.abs(t), df)
            half_width = scipy.stats.t.ppf((1 + conf_level) / 2, df) * denom
            confint = [estimate - half_width, estimate + half_width]
        elif alternative == 'less':
            prob = scipy.stats.t.cdf(t, df)
            confint[1] = estimate + scipy.stats.t.ppf(conf_level, df) * denom
        else:
            prob = scipy.stats.t.sf(t, df)
            confint[0] = estimate - scipy.stats.t.ppf(conf_level, df) * denom
        return confint, prob

    def __one_sample_ttest(x, mu, conf_level, alternative):
        """
        This function is used to do one-sample t-test.
        """
        n = x.shape[0]
        if n < 2:
            raise ValueError('not enough observations in x')
        df = n - 1
        bar_x = np.mean(x)
        denom = np.sqrt(np.var(x, ddof=1) / n)
        t = np.divide(bar_x - mu, denom)
        confint, prob = __ttest_common(t, bar_x, denom, df, conf_level, alternative)
        return t, prob, df, n, confint, bar_x

    def __two_ind_ttest(x, y, mu, conf_level, alternative, var_equal):
        """
        This function is used to do two independent sample t-test
        (pooled variance if `var_equal`, Welch otherwise).
        """
        n1, n2 = x.shape[0], y.shape[0]
        if n1 < 2 or n2 < 2:
            raise ValueError('x and y must each contain at least 2 observations')
        x1, x2 = np.mean(x), np.mean(y)
        v1, v2 = np.var(x, ddof=1), np.var(y, ddof=1)
        vn1, vn2 = v1 / n1, v2 / n2
        if var_equal:
            df = n1 + n2 - 2
            v = ((n1 - 1) * v1 + (n2 - 1) * v2) / df
            denom = np.sqrt(v * (1.0 / n1 + 1.0 / n2))
        else:
            denom = np.sqrt(vn1 + vn2)
            # Welch-Satterthwaite approximation of the degrees of freedom.
            df = (vn1 + vn2)**2 / (vn1**2 / (n1 - 1) + vn2**2 / (n2 - 1))
        estimate = x1 - x2
        t = np.divide(estimate - mu, denom)
        confint, prob = __ttest_common(t, estimate, denom, df, conf_level, alternative)
        return t, prob, df, (n1, n2), confint, (x1, x2)

    def __paired_ttest(x, y, mu, conf_level, alternative):
        """
        This function is used to do two related samples t-test
        (a one-sample test on the pairwise differences).
        """
        n1 = x.shape[0]
        n2 = y.shape[0]
        if n1 != n2:
            raise ValueError('n1 = {0} not equal to n2 = {1}. In Paired t-test, n1 should equal to n2'.format(n1, n2))
        if n1 < 2:
            raise ValueError('not enough paired observations')

        n = n1
        df = n - 1
        x_y = x - y
        mean_diff = np.mean(x_y)
        denom = np.sqrt(np.var(x_y, ddof=1) / n)
        t = np.divide(mean_diff - mu, denom)
        confint, prob = __ttest_common(t, mean_diff, denom, df, conf_level, alternative)
        return t, prob, df, n, confint, mean_diff

    test_type = __check_test_type(y, paired)

    if test_type == 0:
        table_title = 'One Sample t-test'
        t, prob, df, n, confint, bar_x = __one_sample_ttest(x, mu, conf_level, alternative)
        obs_text = 'num of obs = {0}'.format(n)
        estimate_row = ['sample estimates : ', 'mean of x', np.round(bar_x, 4)]
    elif test_type == 1:
        # Only the unequal-variance variant is Welch's test.
        table_title = 'Two Sample t-test' if var_equal else 'Welch Two Sample t-test'
        t, prob, df, (n1, n2), confint, (x1, x2) = __two_ind_ttest(x, y, mu, conf_level, alternative, var_equal)
        obs_text = 'num of obs = {0} {1}'.format(n1, n2)
        estimate_row = ['sample estimates : ', 'mean of x {0}'.format(np.round(x1, 4)), 'mean of y {0}'.format(np.round(x2, 4))]
    else:
        table_title = 'Paired t-test'
        t, prob, df, n, confint, mean_diff = __paired_ttest(x, y, mu, conf_level, alternative)
        obs_text = 'num of obs = {0}'.format(n)
        estimate_row = ['sample estimates : ', 'mean of the differences', np.round(mean_diff, 4)]

    table = simple_table(title=table_title, unit_width=10)
    table.add_line(style='double')
    table.add_row(['t = {0}'.format(np.round(t, 4)), 'df = {0}'.format(np.round(df, 4)),
                   'p-value = {0}'.format(np.round(prob, 4)), obs_text])
    table.add_row(['{0} percent confidence interval : '.format(conf_level*100),
                   '({0}, {1})'.format(np.round(confint[0], 4), np.round(confint[1], 4))])
    table.add_row(estimate_row)
    table.add_line(style='double')

    return table.__str__()
    

def ztest(x):
    """
    Placeholder for a z-test; not implemented yet.

    The body is empty, so the call currently returns None and `x` is unused.
    """
    pass


def boxtest(x, lag, type='Ljung-Box'):
    """
    Box-Pierce and Ljung-Box Tests
    This function Provide Box-Pierce and Ljung-Box Q-test for residual autocorrelation

    Parameters :
    ----------
    x : array_like.
    lag : The lag order of the Q statistic (integer, 1 <= lag < len(x)).
    type : `Box-Pierce` or `Ljung-Box Q-test`, default option is `Ljung-Box`.
    + 'Box-Pierce'
    + 'Ljung-Box'

    Returns :
    ----------
    A formatted table (str) with the Q statistic, degrees of freedom and p-value.

    Raises :
    ----------
    ValueError if `type` is unknown or `lag` is outside 1 .. len(x) - 1.

    """
    # Validate first: an unknown type used to surface later as a NameError.
    if type not in ('Ljung-Box', 'Box-Pierce'):
        raise ValueError("type should be 'Ljung-Box' or 'Box-Pierce'")

    x = np.asarray(x)
    n = x.shape[0]
    if lag < 1 or lag >= n:
        raise ValueError('lag should be at least 1 and less than the number of observations')

    # NOTE(review): assumes acf(x, lag) returns the autocorrelations at lags
    # 1..lag (length `lag`), matching the weights below - confirm against
    # StatLearnPack.general.acf.
    acf_values = np.asarray(acf(x, lag), dtype=float)

    if type == 'Ljung-Box':
        # Weights n-1, n-2, ..., n-lag for lags 1..lag.
        denom = n - np.arange(1, lag + 1)
        Q = n * (n + 2) * np.sum(acf_values**2 / denom)
        title = 'Box-Ljung test'
    else:
        Q = n * np.sum(acf_values**2)
        title = 'Box-Pierce test'

    # Survival function keeps precision for very small p-values.
    p_value = scipy.stats.chi2.sf(Q, df=lag)
    table = simple_table(title=title)
    table.add_line(style="double")
    table.add_row(["X-squared = {}".format(np.round(Q, 4)), 'df = {}'.format(lag)])
    table.add_row(['p-value = {}'.format(np.round(p_value, 4))])
    table.add_line(style="double")

    return table.__str__()

def proptest(x, n, p=0.5, alternative='tow_sided', conf_level=0.95, correct=True):
    """
    1-sample proportions test
    This function is used to compare the observed proportions with the theoretical proportions.

    Parameters :
    ----------
    x : number of successes (a scalar, or an array_like holding one value)
    n : number of observation
    p :  hypothetical probability, 0 < p < 1
    alternative : str alternative hypothesis:
    'tow_sided' (default; 'two_sided' is accepted as well), 'less' or 'greater'.
    conf_level : confidence level
    correct : if set `correct=True` , we will do continuity correction.

    Returns :
    ----------
    A formatted table (str) with the chi-squared statistic, p-value,
    Wilson score confidence interval and the sample proportion.

    Raises :
    ----------
    ValueError for an unknown alternative, a conf_level outside [0, 1],
    x that is not a single count with 0 <= x <= n, n <= 0, or p outside (0, 1).

    Reference :
    ----------
    You can get more information from :

    [1] Wilson, E.B. (1927). Probable inference, the law of succession, and statistical inference. Journal of the American Statistical Association, 22, 209–212. doi: 10.2307/2276774.

    [2] Newcombe R.G. (1998). Two-Sided Confidence Intervals for the Single Proportion: Comparison of Seven Methods. Statistics in Medicine, 17, 857–872. doi: 10.1002/(SICI)1097-0258(19980430)17:8<857::AID-SIM777>3.0.CO;2-E.

    [3] Newcombe R.G. (1998). Interval Estimation for the Difference Between Independent Proportions: Comparison of Eleven Methods. Statistics in Medicine, 17, 873–890. doi: 10.1002/(SICI)1097-0258(19980430)17:8<873::AID-SIM779>3.0.CO;2-I.
    """

    # 'tow_sided' is the historical spelling of this API; accept the correct
    # spelling too so it matches ttest.
    if alternative == 'two_sided':
        alternative = 'tow_sided'
    if alternative not in ('tow_sided', 'less', 'greater'):
        raise ValueError('alternative should be tow_sided, less or greater')
    if conf_level > 1 or conf_level < 0:
        raise ValueError("conf_level should greater than 0 and less than 1")

    x = np.asarray(x, dtype=float)
    p = np.asarray(p, dtype=float)
    if x.ndim > 1:
        raise ValueError('x should be 1-dimensional')
    if x.size != 1:
        raise ValueError('x should be a single count of successes')
    if p.size != 1:
        raise ValueError('p should be a single probability')
    x = float(x.reshape(-1)[0])
    p = float(p.reshape(-1)[0])
    if n <= 0:
        raise ValueError('n should be greater than 0')
    if x < 0 or x > n:
        raise ValueError('x should be between 0 and n')
    if p <= 0 or p >= 1:
        raise ValueError('p should be greater than 0 and less than 1')

    alpha = 1 - conf_level
    sample_estimate = x / n
    deviation = abs(x - n * p)

    if correct:
        title = '1-sample proportions test with continuity correction'
        # Never correct by more than the observed deviation itself.
        yates = min(0.5, deviation)
    else:
        title = '1-sample proportions test without continuity correction'
        yates = 0.0

    # Sum over the success and failure cells; both cells deviate from their
    # expectation by the same |x - n*p|, so the two terms combine into this.
    chisq = (deviation - yates) ** 2 / (n * p * (1 - p))

    def __wilson_upper(z):
        """Upper Wilson score bound, 1 when the shifted estimate reaches 1."""
        p_c = sample_estimate + yates / n
        if p_c >= 1:
            return 1
        return (2*n*p_c + z**2 + z*np.sqrt(z**2 + 4*n*p_c*(1 - p_c))) / (2*(n + z**2))

    def __wilson_lower(z):
        """Lower Wilson score bound, 0 when the shifted estimate reaches 0."""
        p_c = sample_estimate - yates / n
        if p_c <= 0:
            return 0
        return (2*n*p_c + z**2 - z*np.sqrt(z**2 + 4*n*p_c*(1 - p_c))) / (2*(n + z**2))

    if alternative == 'tow_sided':
        alternative_hypo = 'true p is not equal to {0}'.format(np.round(p, 4))
        p_value = scipy.stats.chi2.sf(chisq, df=1)
        z = scipy.stats.norm.isf(alpha / 2)
        lower, upper = max(__wilson_lower(z), 0), min(__wilson_upper(z), 1)
    else:
        # One-sided p-values come from the signed square root of the statistic.
        z_stat = np.sign(sample_estimate - p) * np.sqrt(chisq)
        z = scipy.stats.norm.isf(alpha)
        if alternative == 'less':
            alternative_hypo = 'true p is less than to {0}'.format(np.round(p, 4))
            p_value = scipy.stats.norm.cdf(z_stat)
            lower, upper = 0, min(__wilson_upper(z), 1)
        else:
            alternative_hypo = 'true p is greater than to {0}'.format(np.round(p, 4))
            p_value = scipy.stats.norm.sf(z_stat)
            lower, upper = max(__wilson_lower(z), 0), 1
    confint = [lower, upper]

    table = simple_table(title=title)
    table.add_line(style='double')
    table.add_row(['X-squared = {0}'.format(np.round(chisq, 4)), 'df = 1', 'p-value = {0}'.format(np.round(p_value, 4))])
    table.add_row(['alternative hypothesis : {0}'.format(alternative_hypo)])
    table.add_row(['{0} percent confidence interval : [{1}, {2}]'.format(
        np.round(conf_level * 100, 2), np.round(confint[0], 4), np.round(confint[1], 4))])
    table.add_row(['sample estimates : ', 'p = {0}'.format(np.round(sample_estimate, 4))])
    table.add_line(style='double')

    return table.__str__()


def chisqtest(x, y=None, p=None, correct=False, rescale_p = True):
    """
    Performs Chi-Squared Test for count data

    Three kinds of input are supported:
    + x 1-dimensional, y omitted : goodness-of-fit test of the counts `x`
      against the probabilities `p` (uniform when `p` is None).
    + x 2-dimensional : Pearson's test of independence on the contingency
      table `x` (`y` is ignored).
    + x and y 1-dimensional : `x` and `y` are treated as paired categorical
      observations and cross-tabulated, then tested for independence.

    Parameters :
    ----------
    x : 1- or 2-dimensional array_like.
    y : optional 1-dimensional array_like with the same length as `x`.
    p : probabilities for the goodness-of-fit test, same length as `x`.
    correct : apply Yates' continuity correction (2 x 2 tables only).
    rescale_p : if True `p` is rescaled to sum to 1, otherwise it must sum to 1.

    Returns :
    ----------
    A formatted table (str) with the statistic, degrees of freedom and p-value.

    Raises :
    ----------
    ValueError for inputs of the wrong dimension or length, empty input,
    negative / non-finite / all-zero counts, or invalid probabilities.

    """

    x = np.asarray(x)
    if x.ndim not in (1, 2):
        raise ValueError("x should be 1-dimensional or 2-dimensional array like data")
    if x.size == 0:
        raise ValueError("x must contain at least one entry")

    # `is not None`: comparing an array with `!= None` is element-wise and
    # cannot be used in an `if`.
    if x.ndim == 1 and y is not None:
        y = np.asarray(y)
        if y.ndim != 1:
            raise ValueError("y should be 1-dimensional array like data")
        if len(x) != len(y):
            raise ValueError("x and y mast have same length")
        # Cross-tabulate the paired observations into a contingency table.
        _, row_idx = np.unique(x, return_inverse=True)
        _, col_idx = np.unique(y, return_inverse=True)
        row_idx, col_idx = np.ravel(row_idx), np.ravel(col_idx)
        counts = np.zeros((row_idx.max() + 1, col_idx.max() + 1))
        np.add.at(counts, (row_idx, col_idx), 1)
    else:
        counts = x.astype(float)

    # `np.nan in x` never matches because NaN != NaN; isfinite catches NaN and inf.
    if not np.all(np.isfinite(counts)) or np.min(counts) < 0:
        raise ValueError("all entries of 'x' must be nonnegative and finite")
    n = np.sum(counts)
    if n == 0:
        raise ValueError("at least one entry of 'x' must be positive")

    if counts.ndim == 1:
        title = "Chi-squared test for given probabilities"
        k = counts.shape[0]
        if p is None:
            probs = np.full(k, 1.0 / k)
        else:
            probs = np.asarray(p, dtype=float)
            if probs.shape != counts.shape:
                raise ValueError("'x' and 'p' must have the same number of elements")
            if np.any(probs < 0):
                raise ValueError("probabilities must be non-negative.")
            if rescale_p:
                probs = probs / np.sum(probs)
            elif abs(np.sum(probs) - 1) > np.sqrt(np.finfo(float).eps):
                raise ValueError("probabilities must sum to 1.")
        expected = n * probs
        chisq = np.sum((counts - expected) ** 2 / expected)
        df = k - 1
    else:
        title = "Pearson's Chi-squared test"
        # Expected counts under independence: row total * column total / n.
        expected = np.outer(np.sum(counts, axis=1), np.sum(counts, axis=0)) / n
        abs_diff = np.abs(counts - expected)
        yates = 0.0
        if correct and counts.shape == (2, 2):
            yates = min(0.5, np.min(abs_diff))
            title = "Pearson's Chi-squared test with Yates' continuity correction"
        chisq = np.sum((abs_diff - yates) ** 2 / expected)
        df = (counts.shape[0] - 1) * (counts.shape[1] - 1)

    p_value = scipy.stats.chi2.sf(chisq, df=df)

    table = simple_table(title=title)
    table.add_line(style='double')
    table.add_row(['X-squared = {0}'.format(np.round(chisq, 4)), 'df = {0}'.format(df),
                   'p-value = {0}'.format(np.round(p_value, 4))])
    table.add_line(style='double')
    return table.__str__()

